Installing required libraries¶
pip install keras torch torchvision seaborn tensorflow
Requirement already satisfied: keras, torch, torchvision, seaborn, tensorflow, and all of their dependencies in /home/zeus/miniconda3/envs/cloudspace/lib/python3.12/site-packages (output truncated).
Note: you may need to restart the kernel to use updated packages.
Importing libraries¶
Note: training was run on lightning.ai for faster compute.
import time
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from keras.datasets import mnist, fashion_mnist
from sklearn.metrics import accuracy_score, f1_score, confusion_matrix
from sklearn.ensemble import RandomForestClassifier
from sklearn.linear_model import LogisticRegression
from sklearn.manifold import TSNE
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.utils.data import DataLoader, Subset
from torchvision import datasets, transforms, models
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)
Device: cuda
print("CUDA available:", torch.cuda.is_available())
print("Device name:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU only")
CUDA available: True
Device name: NVIDIA L40S
Helper functions¶
torch.manual_seed(42)
np.random.seed(42)
def show_confusion(cm, labels, title='Confusion Matrix'):
    plt.figure(figsize=(7, 6))
    sns.heatmap(cm, annot=True, fmt='d', cmap='Blues', xticklabels=labels, yticklabels=labels)
    plt.title(title)
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.show()
Importing the dataset and reshaping it¶
(train_X, train_y), (test_X, test_y) = mnist.load_data()

# Flatten each 28x28 image to a 784-dim vector, convert to float tensors, scale to [0, 1]
train_X = torch.from_numpy(train_X).float().reshape(-1, 784) / 255
train_y = torch.from_numpy(train_y).to(torch.int64)
test_X = torch.from_numpy(test_X).float().reshape(-1, 784) / 255
test_y = torch.from_numpy(test_y).to(torch.int64)
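As a quick sanity check (an illustrative snippet, not part of the original notebook), the resulting tensor shapes can be printed:

# MNIST has 60,000 training and 10,000 test examples
print(train_X.shape, train_y.shape)  # torch.Size([60000, 784]) torch.Size([60000])
print(test_X.shape, test_y.shape)    # torch.Size([10000, 784]) torch.Size([10000])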
Evaluation and visualization helpers¶
def evaluate_model(model, X_test, y_test, device='cpu'):
    model.eval()
    with torch.no_grad():
        X_test = X_test.to(device)
        y_test = y_test.to(device)
        outputs = model(X_test)
        loss = F.cross_entropy(outputs, y_test).item()
        preds = outputs.argmax(dim=1).cpu().numpy()
    y_true = y_test.cpu().numpy()
    acc = accuracy_score(y_true, preds)
    f1 = f1_score(y_true, preds, average='macro')
    cm = confusion_matrix(y_true, preds)
    return preds, acc, f1, cm, loss
def summary(name, acc, f1, cm, train_losses=None):
    print(f"\n{name}")
    print(f"Accuracy: {acc:.4f}, F1-score: {f1:.4f}")
    show_confusion(cm, list(range(10)), title=f"{name} Confusion Matrix")
    if train_losses is not None:
        plt.figure(figsize=(7, 5))
        plt.plot(train_losses, label="Training Loss", linewidth=2)
        plt.xlabel("Epochs")
        plt.ylabel("Loss")
        plt.title(f"Training Loss vs Epochs - {name}")
        plt.legend()
        plt.grid(alpha=0.3)
        plt.tight_layout()
        plt.show()
def visualize_tsne(model, X, y, trained=True, device='cpu'):
    model.eval()
    X, y = X.to(device), y.cpu().numpy()
    with torch.no_grad():
        x = F.relu(model.fc1(X))
        layer2_out = model.fc2(x).cpu().numpy()
    tsne = TSNE(n_components=2, random_state=42)
    tsne_results = tsne.fit_transform(layer2_out)
    plt.figure(figsize=(10, 6))
    num_classes = len(np.unique(y))
    for i in range(num_classes):
        indices = (y == i)
        plt.scatter(
            tsne_results[indices, 0],
            tsne_results[indices, 1],
            label=i,
            alpha=0.5,
        )
    plt.legend(title="Class")
    plt.title(f"t-SNE (20-neuron layer) - {'Trained' if trained else 'Untrained'} Model")
    plt.tight_layout()
    plt.show()
Defining the MLP¶
class MLP_relu(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 30)
        self.fc2 = nn.Linear(30, 20)
        self.fc3 = nn.Linear(20, 10)

    def forward(self, x):
        x = x.view(len(x), -1)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        return self.fc3(x)

class MLP_sigmoid(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc1 = nn.Linear(28*28, 30)
        self.fc2 = nn.Linear(30, 20)
        self.fc3 = nn.Linear(20, 10)

    def forward(self, x):
        x = x.view(len(x), -1)
        x = torch.sigmoid(self.fc1(x))  # torch.sigmoid, since F.sigmoid is deprecated
        x = torch.sigmoid(self.fc2(x))
        return self.fc3(x)
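Both architectures share the same 784 → 30 → 20 → 10 topology and therefore the same parameter count; a quick way to verify it (illustrative snippet):

n_params = sum(p.numel() for p in MLP_relu().parameters())
# fc1: 784*30 + 30 = 23550, fc2: 30*20 + 20 = 620, fc3: 20*10 + 10 = 210
print(n_params)  # 24380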
def train_mlp(model, X_train, y_train, epochs=100, lr=0.001):
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    train_losses = []
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()
        outputs = model(X_train)
        loss = criterion(outputs, y_train)
        loss.backward()
        optimizer.step()
        train_losses.append(loss.item())
        print(f"Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}")
    return train_losses
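Note that train_mlp performs full-batch gradient descent: each epoch is a single Adam step over all 60,000 training images at once, which explains the very smooth loss curves below and the need for 1,000 epochs. A mini-batch variant (a sketch only; the runs below use the full-batch version above) could reuse the already-imported DataLoader:

from torch.utils.data import TensorDataset

def train_mlp_minibatch(model, X_train, y_train, epochs=10, lr=0.001, batch_size=128):
    # Same loss and optimizer as train_mlp, but one Adam step per mini-batch
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = nn.CrossEntropyLoss()
    loader = DataLoader(TensorDataset(X_train, y_train), batch_size=batch_size, shuffle=True)
    train_losses = []
    for epoch in range(epochs):
        model.train()
        running = 0.0
        for xb, yb in loader:
            optimizer.zero_grad()
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()
            running += loss.item() * len(xb)
        train_losses.append(running / len(X_train))
        print(f"Epoch {epoch+1}/{epochs}, Loss: {train_losses[-1]:.4f}")
    return train_losses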
Using Cross Entropy Loss and ReLU
mlp_relu = MLP_relu().to(device)
print("\nTraining MLP...")
train_losses = train_mlp(mlp_relu, train_X.to(device), train_y.to(device), epochs=1000)
preds, acc, f1, cm, test_loss = evaluate_model(mlp_relu, test_X, test_y, device)
summary("MLP on MNIST with ReLU", acc, f1, cm, train_losses)
print(f"Test Loss: {test_loss:.4f}")
Training MLP...
Epoch 1/1000, Loss: 2.3110
Epoch 2/1000, Loss: 2.2992
...
Epoch 999/1000, Loss: 0.0666
Epoch 1000/1000, Loss: 0.0665

MLP on MNIST with ReLU
Accuracy: 0.9621, F1-score: 0.9618
Test Loss: 0.1322
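With the ReLU model trained, the visualize_tsne helper defined earlier can be applied to its 20-neuron layer; running it on a random subset keeps t-SNE tractable (the 2,000-sample size here is an arbitrary choice for this sketch):

# Embed the 20-neuron layer for a random subset of the test set
idx = torch.randperm(len(test_X))[:2000]
visualize_tsne(mlp_relu, test_X[idx], test_y[idx], trained=True, device=device)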
Using Cross Entropy Loss and Sigmoid
mlp_sigmoid = MLP_sigmoid().to(device)
print("\nTraining MLP...")
train_losses = train_mlp(mlp_sigmoid, train_X.to(device), train_y.to(device), epochs=1000)
preds, acc, f1, cm, test_loss = evaluate_model(mlp_sigmoid, test_X, test_y, device)
summary("MLP on MNIST with Sigmoid", acc, f1, cm, train_losses)
print(f"Test Loss: {test_loss:.4f}")
Training MLP...
Epoch 1/1000, Loss: 2.3312
Epoch 2/1000, Loss: 2.3266
...
Epoch 615/1000, Loss: 0.4829
Epoch 616/1000, Loss:
0.4819 Epoch 617/1000, Loss: 0.4809 Epoch 618/1000, Loss: 0.4799 Epoch 619/1000, Loss: 0.4789 Epoch 620/1000, Loss: 0.4779 Epoch 621/1000, Loss: 0.4769 Epoch 622/1000, Loss: 0.4759 Epoch 623/1000, Loss: 0.4749 Epoch 624/1000, Loss: 0.4740 Epoch 625/1000, Loss: 0.4730 Epoch 626/1000, Loss: 0.4720 Epoch 627/1000, Loss: 0.4710 Epoch 628/1000, Loss: 0.4701 Epoch 629/1000, Loss: 0.4691 Epoch 630/1000, Loss: 0.4682 Epoch 631/1000, Loss: 0.4672 Epoch 632/1000, Loss: 0.4662 Epoch 633/1000, Loss: 0.4653 Epoch 634/1000, Loss: 0.4644 Epoch 635/1000, Loss: 0.4634 Epoch 636/1000, Loss: 0.4625 Epoch 637/1000, Loss: 0.4615 Epoch 638/1000, Loss: 0.4606 Epoch 639/1000, Loss: 0.4597 Epoch 640/1000, Loss: 0.4587 Epoch 641/1000, Loss: 0.4578 Epoch 642/1000, Loss: 0.4569 Epoch 643/1000, Loss: 0.4560 Epoch 644/1000, Loss: 0.4551 Epoch 645/1000, Loss: 0.4542 Epoch 646/1000, Loss: 0.4533 Epoch 647/1000, Loss: 0.4523 Epoch 648/1000, Loss: 0.4514 Epoch 649/1000, Loss: 0.4505 Epoch 650/1000, Loss: 0.4496 Epoch 651/1000, Loss: 0.4488 Epoch 652/1000, Loss: 0.4479 Epoch 653/1000, Loss: 0.4470 Epoch 654/1000, Loss: 0.4461 Epoch 655/1000, Loss: 0.4452 Epoch 656/1000, Loss: 0.4443 Epoch 657/1000, Loss: 0.4435 Epoch 658/1000, Loss: 0.4426 Epoch 659/1000, Loss: 0.4417 Epoch 660/1000, Loss: 0.4409 Epoch 661/1000, Loss: 0.4400 Epoch 662/1000, Loss: 0.4391 Epoch 663/1000, Loss: 0.4383 Epoch 664/1000, Loss: 0.4374 Epoch 665/1000, Loss: 0.4366 Epoch 666/1000, Loss: 0.4357 Epoch 667/1000, Loss: 0.4349 Epoch 668/1000, Loss: 0.4340 Epoch 669/1000, Loss: 0.4332 Epoch 670/1000, Loss: 0.4323 Epoch 671/1000, Loss: 0.4315 Epoch 672/1000, Loss: 0.4307 Epoch 673/1000, Loss: 0.4298 Epoch 674/1000, Loss: 0.4290 Epoch 675/1000, Loss: 0.4282 Epoch 676/1000, Loss: 0.4274 Epoch 677/1000, Loss: 0.4266 Epoch 678/1000, Loss: 0.4257 Epoch 679/1000, Loss: 0.4249 Epoch 680/1000, Loss: 0.4241 Epoch 681/1000, Loss: 0.4233 Epoch 682/1000, Loss: 0.4225 Epoch 683/1000, Loss: 0.4217 Epoch 684/1000, Loss: 0.4209 Epoch 685/1000, Loss: 0.4201 Epoch 686/1000, Loss: 0.4193 Epoch 687/1000, Loss: 0.4185 Epoch 688/1000, Loss: 0.4177 Epoch 689/1000, Loss: 0.4170 Epoch 690/1000, Loss: 0.4162 Epoch 691/1000, Loss: 0.4154 Epoch 692/1000, Loss: 0.4146 Epoch 693/1000, Loss: 0.4138 Epoch 694/1000, Loss: 0.4131 Epoch 695/1000, Loss: 0.4123 Epoch 696/1000, Loss: 0.4115 Epoch 697/1000, Loss: 0.4108 Epoch 698/1000, Loss: 0.4100 Epoch 699/1000, Loss: 0.4092 Epoch 700/1000, Loss: 0.4085 Epoch 701/1000, Loss: 0.4077 Epoch 702/1000, Loss: 0.4070 Epoch 703/1000, Loss: 0.4062 Epoch 704/1000, Loss: 0.4055 Epoch 705/1000, Loss: 0.4047 Epoch 706/1000, Loss: 0.4040 Epoch 707/1000, Loss: 0.4033 Epoch 708/1000, Loss: 0.4025 Epoch 709/1000, Loss: 0.4018 Epoch 710/1000, Loss: 0.4011 Epoch 711/1000, Loss: 0.4003 Epoch 712/1000, Loss: 0.3996 Epoch 713/1000, Loss: 0.3989 Epoch 714/1000, Loss: 0.3982 Epoch 715/1000, Loss: 0.3974 Epoch 716/1000, Loss: 0.3967 Epoch 717/1000, Loss: 0.3960 Epoch 718/1000, Loss: 0.3953 Epoch 719/1000, Loss: 0.3946 Epoch 720/1000, Loss: 0.3939 Epoch 721/1000, Loss: 0.3932 Epoch 722/1000, Loss: 0.3925 Epoch 723/1000, Loss: 0.3918 Epoch 724/1000, Loss: 0.3911 Epoch 725/1000, Loss: 0.3904 Epoch 726/1000, Loss: 0.3897 Epoch 727/1000, Loss: 0.3890 Epoch 728/1000, Loss: 0.3883 Epoch 729/1000, Loss: 0.3876 Epoch 730/1000, Loss: 0.3869 Epoch 731/1000, Loss: 0.3862 Epoch 732/1000, Loss: 0.3856 Epoch 733/1000, Loss: 0.3849 Epoch 734/1000, Loss: 0.3842 Epoch 735/1000, Loss: 0.3835 Epoch 736/1000, Loss: 0.3829 Epoch 737/1000, Loss: 0.3822 Epoch 738/1000, Loss: 0.3815 Epoch 
739/1000, Loss: 0.3809 Epoch 740/1000, Loss: 0.3802 Epoch 741/1000, Loss: 0.3795 Epoch 742/1000, Loss: 0.3789 Epoch 743/1000, Loss: 0.3782 Epoch 744/1000, Loss: 0.3776 Epoch 745/1000, Loss: 0.3769 Epoch 746/1000, Loss: 0.3763 Epoch 747/1000, Loss: 0.3756 Epoch 748/1000, Loss: 0.3750 Epoch 749/1000, Loss: 0.3743 Epoch 750/1000, Loss: 0.3737 Epoch 751/1000, Loss: 0.3730 Epoch 752/1000, Loss: 0.3724 Epoch 753/1000, Loss: 0.3718 Epoch 754/1000, Loss: 0.3711 Epoch 755/1000, Loss: 0.3705 Epoch 756/1000, Loss: 0.3699 Epoch 757/1000, Loss: 0.3692 Epoch 758/1000, Loss: 0.3686 Epoch 759/1000, Loss: 0.3680 Epoch 760/1000, Loss: 0.3674 Epoch 761/1000, Loss: 0.3667 Epoch 762/1000, Loss: 0.3661 Epoch 763/1000, Loss: 0.3655 Epoch 764/1000, Loss: 0.3649 Epoch 765/1000, Loss: 0.3643 Epoch 766/1000, Loss: 0.3637 Epoch 767/1000, Loss: 0.3631 Epoch 768/1000, Loss: 0.3624 Epoch 769/1000, Loss: 0.3618 Epoch 770/1000, Loss: 0.3612 Epoch 771/1000, Loss: 0.3606 Epoch 772/1000, Loss: 0.3600 Epoch 773/1000, Loss: 0.3594 Epoch 774/1000, Loss: 0.3588 Epoch 775/1000, Loss: 0.3582 Epoch 776/1000, Loss: 0.3576 Epoch 777/1000, Loss: 0.3571 Epoch 778/1000, Loss: 0.3565 Epoch 779/1000, Loss: 0.3559 Epoch 780/1000, Loss: 0.3553 Epoch 781/1000, Loss: 0.3547 Epoch 782/1000, Loss: 0.3541 Epoch 783/1000, Loss: 0.3535 Epoch 784/1000, Loss: 0.3530 Epoch 785/1000, Loss: 0.3524 Epoch 786/1000, Loss: 0.3518 Epoch 787/1000, Loss: 0.3512 Epoch 788/1000, Loss: 0.3507 Epoch 789/1000, Loss: 0.3501 Epoch 790/1000, Loss: 0.3495 Epoch 791/1000, Loss: 0.3490 Epoch 792/1000, Loss: 0.3484 Epoch 793/1000, Loss: 0.3478 Epoch 794/1000, Loss: 0.3473 Epoch 795/1000, Loss: 0.3467 Epoch 796/1000, Loss: 0.3462 Epoch 797/1000, Loss: 0.3456 Epoch 798/1000, Loss: 0.3451 Epoch 799/1000, Loss: 0.3445 Epoch 800/1000, Loss: 0.3439 Epoch 801/1000, Loss: 0.3434 Epoch 802/1000, Loss: 0.3429 Epoch 803/1000, Loss: 0.3423 Epoch 804/1000, Loss: 0.3418 Epoch 805/1000, Loss: 0.3412 Epoch 806/1000, Loss: 0.3407 Epoch 807/1000, Loss: 0.3401 Epoch 808/1000, Loss: 0.3396 Epoch 809/1000, Loss: 0.3391 Epoch 810/1000, Loss: 0.3385 Epoch 811/1000, Loss: 0.3380 Epoch 812/1000, Loss: 0.3375 Epoch 813/1000, Loss: 0.3369 Epoch 814/1000, Loss: 0.3364 Epoch 815/1000, Loss: 0.3359 Epoch 816/1000, Loss: 0.3354 Epoch 817/1000, Loss: 0.3348 Epoch 818/1000, Loss: 0.3343 Epoch 819/1000, Loss: 0.3338 Epoch 820/1000, Loss: 0.3333 Epoch 821/1000, Loss: 0.3328 Epoch 822/1000, Loss: 0.3322 Epoch 823/1000, Loss: 0.3317 Epoch 824/1000, Loss: 0.3312 Epoch 825/1000, Loss: 0.3307 Epoch 826/1000, Loss: 0.3302 Epoch 827/1000, Loss: 0.3297 Epoch 828/1000, Loss: 0.3292 Epoch 829/1000, Loss: 0.3287 Epoch 830/1000, Loss: 0.3282 Epoch 831/1000, Loss: 0.3277 Epoch 832/1000, Loss: 0.3272 Epoch 833/1000, Loss: 0.3267 Epoch 834/1000, Loss: 0.3262 Epoch 835/1000, Loss: 0.3257 Epoch 836/1000, Loss: 0.3252 Epoch 837/1000, Loss: 0.3247 Epoch 838/1000, Loss: 0.3242 Epoch 839/1000, Loss: 0.3237 Epoch 840/1000, Loss: 0.3232 Epoch 841/1000, Loss: 0.3227 Epoch 842/1000, Loss: 0.3223 Epoch 843/1000, Loss: 0.3218 Epoch 844/1000, Loss: 0.3213 Epoch 845/1000, Loss: 0.3208 Epoch 846/1000, Loss: 0.3203 Epoch 847/1000, Loss: 0.3199 Epoch 848/1000, Loss: 0.3194 Epoch 849/1000, Loss: 0.3189 Epoch 850/1000, Loss: 0.3184 Epoch 851/1000, Loss: 0.3180 Epoch 852/1000, Loss: 0.3175 Epoch 853/1000, Loss: 0.3170 Epoch 854/1000, Loss: 0.3166 Epoch 855/1000, Loss: 0.3161 Epoch 856/1000, Loss: 0.3156 Epoch 857/1000, Loss: 0.3152 Epoch 858/1000, Loss: 0.3147 Epoch 859/1000, Loss: 0.3142 Epoch 860/1000, Loss: 0.3138 Epoch 861/1000, Loss: 
0.3133 Epoch 862/1000, Loss: 0.3129 Epoch 863/1000, Loss: 0.3124 Epoch 864/1000, Loss: 0.3119 Epoch 865/1000, Loss: 0.3115 Epoch 866/1000, Loss: 0.3110 Epoch 867/1000, Loss: 0.3106 Epoch 868/1000, Loss: 0.3101 Epoch 869/1000, Loss: 0.3097 Epoch 870/1000, Loss: 0.3092 Epoch 871/1000, Loss: 0.3088 Epoch 872/1000, Loss: 0.3083 Epoch 873/1000, Loss: 0.3079 Epoch 874/1000, Loss: 0.3075 Epoch 875/1000, Loss: 0.3070 Epoch 876/1000, Loss: 0.3066 Epoch 877/1000, Loss: 0.3061 Epoch 878/1000, Loss: 0.3057 Epoch 879/1000, Loss: 0.3053 Epoch 880/1000, Loss: 0.3048 Epoch 881/1000, Loss: 0.3044 Epoch 882/1000, Loss: 0.3040 Epoch 883/1000, Loss: 0.3035 Epoch 884/1000, Loss: 0.3031 Epoch 885/1000, Loss: 0.3027 Epoch 886/1000, Loss: 0.3022 Epoch 887/1000, Loss: 0.3018 Epoch 888/1000, Loss: 0.3014 Epoch 889/1000, Loss: 0.3010 Epoch 890/1000, Loss: 0.3005 Epoch 891/1000, Loss: 0.3001 Epoch 892/1000, Loss: 0.2997 Epoch 893/1000, Loss: 0.2993 Epoch 894/1000, Loss: 0.2989 Epoch 895/1000, Loss: 0.2984 Epoch 896/1000, Loss: 0.2980 Epoch 897/1000, Loss: 0.2976 Epoch 898/1000, Loss: 0.2972 Epoch 899/1000, Loss: 0.2968 Epoch 900/1000, Loss: 0.2964 Epoch 901/1000, Loss: 0.2960 Epoch 902/1000, Loss: 0.2955 Epoch 903/1000, Loss: 0.2951 Epoch 904/1000, Loss: 0.2947 Epoch 905/1000, Loss: 0.2943 Epoch 906/1000, Loss: 0.2939 Epoch 907/1000, Loss: 0.2935 Epoch 908/1000, Loss: 0.2931 Epoch 909/1000, Loss: 0.2927 Epoch 910/1000, Loss: 0.2923 Epoch 911/1000, Loss: 0.2919 Epoch 912/1000, Loss: 0.2915 Epoch 913/1000, Loss: 0.2911 Epoch 914/1000, Loss: 0.2907 Epoch 915/1000, Loss: 0.2903 Epoch 916/1000, Loss: 0.2899 Epoch 917/1000, Loss: 0.2895 Epoch 918/1000, Loss: 0.2891 Epoch 919/1000, Loss: 0.2887 Epoch 920/1000, Loss: 0.2883 Epoch 921/1000, Loss: 0.2880 Epoch 922/1000, Loss: 0.2876 Epoch 923/1000, Loss: 0.2872 Epoch 924/1000, Loss: 0.2868 Epoch 925/1000, Loss: 0.2864 Epoch 926/1000, Loss: 0.2860 Epoch 927/1000, Loss: 0.2856 Epoch 928/1000, Loss: 0.2853 Epoch 929/1000, Loss: 0.2849 Epoch 930/1000, Loss: 0.2845 Epoch 931/1000, Loss: 0.2841 Epoch 932/1000, Loss: 0.2837 Epoch 933/1000, Loss: 0.2834 Epoch 934/1000, Loss: 0.2830 Epoch 935/1000, Loss: 0.2826 Epoch 936/1000, Loss: 0.2822 Epoch 937/1000, Loss: 0.2819 Epoch 938/1000, Loss: 0.2815 Epoch 939/1000, Loss: 0.2811 Epoch 940/1000, Loss: 0.2807 Epoch 941/1000, Loss: 0.2804 Epoch 942/1000, Loss: 0.2800 Epoch 943/1000, Loss: 0.2796 Epoch 944/1000, Loss: 0.2793 Epoch 945/1000, Loss: 0.2789 Epoch 946/1000, Loss: 0.2785 Epoch 947/1000, Loss: 0.2782 Epoch 948/1000, Loss: 0.2778 Epoch 949/1000, Loss: 0.2775 Epoch 950/1000, Loss: 0.2771 Epoch 951/1000, Loss: 0.2767 Epoch 952/1000, Loss: 0.2764 Epoch 953/1000, Loss: 0.2760 Epoch 954/1000, Loss: 0.2757 Epoch 955/1000, Loss: 0.2753 Epoch 956/1000, Loss: 0.2749 Epoch 957/1000, Loss: 0.2746 Epoch 958/1000, Loss: 0.2742 Epoch 959/1000, Loss: 0.2739 Epoch 960/1000, Loss: 0.2735 Epoch 961/1000, Loss: 0.2732 Epoch 962/1000, Loss: 0.2728 Epoch 963/1000, Loss: 0.2725 Epoch 964/1000, Loss: 0.2721 Epoch 965/1000, Loss: 0.2718 Epoch 966/1000, Loss: 0.2714 Epoch 967/1000, Loss: 0.2711 Epoch 968/1000, Loss: 0.2707 Epoch 969/1000, Loss: 0.2704 Epoch 970/1000, Loss: 0.2700 Epoch 971/1000, Loss: 0.2697 Epoch 972/1000, Loss: 0.2694 Epoch 973/1000, Loss: 0.2690 Epoch 974/1000, Loss: 0.2687 Epoch 975/1000, Loss: 0.2683 Epoch 976/1000, Loss: 0.2680 Epoch 977/1000, Loss: 0.2677 Epoch 978/1000, Loss: 0.2673 Epoch 979/1000, Loss: 0.2670 Epoch 980/1000, Loss: 0.2666 Epoch 981/1000, Loss: 0.2663 Epoch 982/1000, Loss: 0.2660 Epoch 983/1000, Loss: 0.2656 Epoch 
984/1000, Loss: 0.2653 Epoch 985/1000, Loss: 0.2650 Epoch 986/1000, Loss: 0.2646 Epoch 987/1000, Loss: 0.2643 Epoch 988/1000, Loss: 0.2640 Epoch 989/1000, Loss: 0.2637 Epoch 990/1000, Loss: 0.2633 Epoch 991/1000, Loss: 0.2630 Epoch 992/1000, Loss: 0.2627 Epoch 993/1000, Loss: 0.2623 Epoch 994/1000, Loss: 0.2620 Epoch 995/1000, Loss: 0.2617 Epoch 996/1000, Loss: 0.2614 Epoch 997/1000, Loss: 0.2610 Epoch 998/1000, Loss: 0.2607 Epoch 999/1000, Loss: 0.2604 Epoch 1000/1000, Loss: 0.2601 MLP on MNIST with Sigmoid Accuracy: 0.9325, F1-score: 0.9314
Test Loss: 0.2811
rf = RandomForestClassifier(n_estimators=100, random_state=42)
rf.fit(train_X, train_y)
rf_preds = rf.predict(test_X)
rf_acc = accuracy_score(test_y, rf_preds)
rf_f1 = f1_score(test_y, rf_preds, average='macro')
rf_cm = confusion_matrix(test_y, rf_preds)
summary("Random Forest", rf_acc, rf_f1, rf_cm)
log_reg = LogisticRegression(max_iter=1000)
log_reg.fit(train_X, train_y)
log_preds = log_reg.predict(test_X)
log_acc = accuracy_score(test_y, log_preds)
log_f1 = f1_score(test_y, log_preds, average='macro')
log_cm = confusion_matrix(test_y, log_preds)
summary("Logistic Regression", log_acc, log_f1, log_cm)
Random Forest Accuracy: 0.9704, F1-score: 0.9702
Logistic Regression Accuracy: 0.9256, F1-score: 0.9245
print("MLP untrained with relu")
mlp_untrained_relu = MLP_relu().to(device)
visualize_tsne(mlp_untrained_relu, test_X, test_y, trained=False, device=device)
print("MLP untrained with sigmoid")
mlp_untrained_sigmoid = MLP_sigmoid().to(device)
visualize_tsne(mlp_untrained_sigmoid, test_X, test_y, trained=False, device=device)
print("MLP trained with relu")
visualize_tsne(mlp_relu, test_X, test_y, trained=True, device=device)
print("MLP trained with sigmoid")
visualize_tsne(mlp_sigmoid, test_X, test_y, trained=True, device=device)
MLP untrained with relu
MLP untrained with sigmoid
MLP trained with relu
MLP trained with sigmoid
t-SNE Comparison¶
t-SNE is a non-linear dimensionality reduction technique used to map high-dimensional data into a 2D or 3D space for visualisation. In the plots for the untrained MLPs (ReLU and Sigmoid) there are no clusters or patterns. After training, however, each of the 10 classes forms its own cluster, with very few outliers.
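The visualize_tsne helper is defined earlier in the notebook; for reference, the core of such a helper is only a few lines with scikit-learn. The sketch below is ours, not the notebook's implementation: embed_fn stands for a hypothetical callable that maps a batch of inputs to the hidden-layer activations being visualised.

import numpy as np
import matplotlib.pyplot as plt
import torch
from sklearn.manifold import TSNE

def tsne_scatter(embed_fn, X, y, n_samples=2000, title="t-SNE of embeddings"):
    # Subsample: t-SNE scales poorly with the number of points
    with torch.no_grad():
        feats = embed_fn(X[:n_samples]).cpu().numpy()
    emb = TSNE(n_components=2, init="pca", random_state=42).fit_transform(feats)
    plt.figure(figsize=(6, 6))
    sc = plt.scatter(emb[:, 0], emb[:, 1], c=np.asarray(y[:n_samples]), cmap="tab10", s=4)
    plt.colorbar(sc, label="class")
    plt.title(title)
    plt.show()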
(f_train_X, f_train_y), (f_test_X, f_test_y) = fashion_mnist.load_data()
f_test_X = torch.from_numpy(f_test_X.reshape(-1, 784)).float() / 255
f_test_y = torch.from_numpy(f_test_y).to(torch.int64)
preds_f, acc_f, f1_f, cm_f, test_loss_f = evaluate_model(mlp_relu, f_test_X, f_test_y, device)
summary("MLP on Fashion-MNIST", acc_f, f1_f, cm_f)
print(f"Test Loss: {test_loss_f:.4f}")
visualize_tsne(mlp_relu, f_test_X, f_test_y, trained=True, device=device)
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/train-images-idx3-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-labels-idx1-ubyte.gz
Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/t10k-images-idx3-ubyte.gz
/tmp/ipykernel_6110/160832730.py:2: UserWarning: The given NumPy array is not writable, and PyTorch does not support non-writable tensors. This means writing to this tensor will result in undefined behavior. You may want to copy the array to protect its data or make it writable before converting it to a tensor. This type of warning will be suppressed for the rest of this program. (Triggered internally at /pytorch/torch/csrc/utils/tensor_numpy.cpp:203.) f_test_X = torch.from_numpy(f_test_X.reshape(-1, 784)).float() / 255
MLP (ReLU) on Fashion-MNIST Accuracy: 0.0662, F1-score: 0.0457
Test Loss: 30.0740
print("\nTesting MLP (Sigmoid) on Fashion-MNIST...")
(f_train_X, f_train_y), (f_test_X, f_test_y) = fashion_mnist.load_data()
f_test_X = torch.from_numpy(f_test_X.reshape(-1, 784)).float() / 255
f_test_y = torch.from_numpy(f_test_y).to(torch.int64)
preds_f, acc_f, f1_f, cm_f, test_loss_f = evaluate_model(mlp_sigmoid, f_test_X, f_test_y, device)
summary("MLP on Fashion-MNIST", acc_f, f1_f, cm_f)
print(f"Test Loss: {test_loss_f:.4f}")
visualize_tsne(mlp_sigmoid, f_test_X, f_test_y, trained=True, device=device)
Testing MLP (Sigmoid) on Fashion-MNIST...
MLP (Sigmoid) on Fashion-MNIST Accuracy: 0.1123, F1-score: 0.0623
Test Loss: 4.1060
Testing Trained MLP on Fashion MNIST¶
The Fashion MNIST dataset is similar to MNIST: 28 x 28 grayscale images with 10 output classes. When we apply our MNIST-trained MLPs to Fashion MNIST, accuracy is very poor: around 6.6% with ReLU as activation and 11.2% with Sigmoid. We could get better results by freezing the current embeddings, adding 1 or 2 more layers, and training only those for a few epochs (see the sketch after the t-SNE comparison below). From the t-SNE plot there is no visible clustering of the data, which suggests that no features useful for Fashion MNIST have been learnt. Moreover, both MLPs predict class 2 far more often than any other class.
Summary of Results for Question 1¶
| Model | Activation Function | Accuracy | F1 Score |
|---|---|---|---|
| MLP (ReLU) | ReLU | 0.9621 | 0.9618 |
| MLP (Sigmoid) | Sigmoid | 0.9325 | 0.9314 |
| Random Forest | N/A | 0.9704 | 0.9702 |
| Logistic Regression | Softmax | 0.9256 | 0.9245 |
Fashion MNIST¶
| Model | Activation | Accuracy | F1 Score |
|---|---|---|---|
| MLP (ReLU) | ReLU | 0.0662 | 0.0457 |
| MLP (Sigmoid) | Sigmoid | 0.1123 | 0.0623 |
We observe that the best performing model is the Random Forest, closely followed by the MLP with ReLU and then the MLP with Sigmoid, with Logistic Regression performing the worst. The MLPs' gap to the Random Forest could be explained by the fact that they have only 2 hidden layers with few neurons. Logistic Regression performs significantly worse than the others because it can only learn a linear decision boundary. Moreover, none of these models exploits the spatial structure of the images; each pixel is treated as an independent feature. We should therefore be able to achieve higher accuracy with CNNs, which can capture local features such as edges and corners. We also observe many misclassifications, particularly between 2 and 7, which can be attributed to the fact that these models have no notion of pixel ordering; a sketch for reading such pairs out of the confusion matrix follows below.
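To back up claims about specific digit pairs, the largest off-diagonal entries can be read straight out of a confusion matrix. The helper below is a minimal sketch (the name most_confused_pairs is ours); cm can be any of the confusion matrices computed above, e.g. rf_cm.

import numpy as np

def most_confused_pairs(cm, k=5):
    # Zero the diagonal so only misclassifications remain
    off_diag = cm.copy()
    np.fill_diagonal(off_diag, 0)
    # Flat indices of the k largest off-diagonal counts
    top = np.argsort(off_diag, axis=None)[::-1][:k]
    rows, cols = np.unravel_index(top, cm.shape)
    return [(int(r), int(c), int(off_diag[r, c])) for r, c in zip(rows, cols)]

# Example: most_confused_pairs(rf_cm) returns [(true_class, predicted_class, count), ...]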
Comparison of t-SNE plots of MNIST and Fashion MNIST¶
In the t-SNE plot of the 20-neuron layer for the Fashion MNIST dataset, no clear clustering is visible, which suggests that the required features haven't been learnt and the model cannot distinguish between the classes. In contrast, there is clear and visible clustering in the t-SNE plot for the same model evaluated on MNIST. We conclude that the feature representations learnt on MNIST do not transfer well to Fashion MNIST. We could achieve higher accuracy by freezing the current model, adding 1-2 more layers, and training only the new layers on Fashion MNIST for a few epochs, for example as sketched below.
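As a concrete illustration of that idea, the frozen MLP could be wrapped as a feature extractor with a small trainable head on top. This is only a sketch: FrozenMLPHead is our name, the layer sizes are illustrative, and it assumes a hypothetical embed method on the base model that returns the 20-neuron penultimate activations.

import torch
import torch.nn as nn

class FrozenMLPHead(nn.Module):
    def __init__(self, base, feat_dim=20, n_classes=10):
        super().__init__()
        self.base = base
        for p in self.base.parameters():  # freeze the MNIST-trained weights
            p.requires_grad = False
        # New trainable head on top of the frozen embeddings
        self.head = nn.Sequential(
            nn.Linear(feat_dim, 64),
            nn.ReLU(),
            nn.Linear(64, n_classes),
        )

    def forward(self, x):
        with torch.no_grad():
            # Assumption: `embed` returns the penultimate hidden activations
            feats = self.base.embed(x)
        return self.head(feats)

Only the head's parameters would be passed to the optimizer, so a few epochs on Fashion MNIST are cheap compared to training from scratch.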
class CNN(nn.Module):
    def __init__(self):
        super().__init__()
        # 1 input channel -> 32 feature maps, 3x3 kernel (no padding): 28x28 -> 26x26
        self.conv1 = nn.Conv2d(1, 32, 3)
        # 2x2 max-pooling halves the spatial size: 26x26 -> 13x13
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(32 * 13 * 13, 128)
        self.fc2 = nn.Linear(128, 10)

    def forward(self, x):
        x = x.view(-1, 1, 28, 28)      # flat 784-vectors -> 1-channel images
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 32 * 13 * 13)   # flatten the feature maps for the FC head
        x = F.relu(self.fc1(x))
        return self.fc2(x)             # raw logits; CrossEntropyLoss applies log-softmax
cnn = CNN().to(device)
opt = optim.Adam(cnn.parameters(), lr=0.001)
loss_fn = nn.CrossEntropyLoss()
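To see where 32 * 13 * 13 comes from: a 3x3 convolution without padding shrinks 28x28 to 26x26, and the 2x2 max-pool halves that to 13x13. A quick check of the intermediate shape (a hypothetical snippet, reusing the cnn and device defined above):

import torch
import torch.nn.functional as F

x = torch.randn(4, 784, device=device)  # four fake flattened MNIST images
h = cnn.pool(F.relu(cnn.conv1(x.view(-1, 1, 28, 28))))
print(h.shape)  # torch.Size([4, 32, 13, 13])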
from torch.utils.data import TensorDataset, DataLoader

batch_size = 64
train_dataset = TensorDataset(train_X, train_y)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
train_losses = []
for epoch in range(100):
    cnn.train()
    total_loss = 0
    for X_batch, y_batch in train_loader:
        X_batch, y_batch = X_batch.to(device), y_batch.to(device)
        opt.zero_grad()
        out = cnn(X_batch)
        loss = loss_fn(out, y_batch)
        loss.backward()
        opt.step()
        total_loss += loss.item()
    avg_loss = total_loss / len(train_loader)
    train_losses.append(avg_loss)
    print(f"Epoch {epoch+1}/100, Loss: {avg_loss:.4f}")
Epoch 1/100, Loss: 0.2182 Epoch 2/100, Loss: 0.0685 Epoch 3/100, Loss: 0.0466 Epoch 4/100, Loss: 0.0336 Epoch 5/100, Loss: 0.0247 ... Epoch 34/100, Loss: 0.0000 ... (loss stays at or near zero for most of the remaining epochs, with occasional small spikes) ... Epoch 100/100, Loss: 0.0023
preds, acc_cnn, f1_cnn, cm_cnn, test_loss_cnn = evaluate_model(cnn, test_X, test_y, device)
params_cnn = sum(p.numel() for p in cnn.parameters())
# Time one forward pass over 512 test images.
# Note: CUDA kernels launch asynchronously, so for precise GPU timings
# torch.cuda.synchronize() should be called before reading time.time().
start = time.time()
with torch.no_grad():
    _ = cnn(test_X[:512].to(device))
t_cnn = time.time() - start
summary("Simple CNN", acc_cnn, f1_cnn, cm_cnn, train_losses)
print(f"Params: {params_cnn:,}, Inference time: {t_cnn:.4f}s, Test Loss: {test_loss_cnn:.4f}")
Simple CNN Accuracy: 0.9851, F1-score: 0.9850
Params: 693,962, Inference time: 0.0006s, Test Loss: 0.1167
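The reported parameter count can be verified by hand: conv1 contributes (3*3*1 + 1)*32 = 320 parameters, fc1 contributes (32*13*13)*128 + 128 = 692,352, and fc2 contributes 128*10 + 10 = 1,290, which sums to 693,962. As a quick check:

conv1 = (3 * 3 * 1 + 1) * 32      # 320 (weights + biases)
fc1 = (32 * 13 * 13) * 128 + 128  # 692,352
fc2 = 128 * 10 + 10               # 1,290
assert conv1 + fc1 + fc2 == 693_962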
batch_size = 128
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
def preprocess_batch(Xb):
    Xb = Xb.to(device)
    # Grayscale -> 3 channels by repeating the single channel
    Xb = Xb.repeat(1, 3, 1, 1)
    # Upsample 28x28 -> 224x224, the input resolution the ImageNet models expect
    Xb = F.interpolate(Xb, size=(224, 224), mode='bilinear', align_corners=False)
    # Normalise with the standard ImageNet channel statistics
    mean = torch.tensor([0.485, 0.456, 0.406], device=device).view(1, 3, 1, 1)
    std = torch.tensor([0.229, 0.224, 0.225], device=device).view(1, 3, 1, 1)
    Xb = (Xb - mean) / std
    return Xb
def evaluate_model_in_batches(model, X, y, batch_size, device):
    model.eval()
    preds = []
    true_labels = []
    total_loss = 0
    criterion = nn.CrossEntropyLoss()
    # Iterate over the test set in batches to keep GPU memory bounded
    for i in range(0, len(X), batch_size):
        Xb = X[i:i + batch_size].to(device)
        yb = y[i:i + batch_size].to(device)
        Xb = preprocess_batch(Xb)
        with torch.no_grad():
            outputs = model(Xb)
            loss = criterion(outputs, yb)
        total_loss += loss.item() * len(Xb)  # weight by batch size for a correct mean
        preds.append(outputs.argmax(dim=1).cpu().numpy())
        true_labels.append(yb.cpu().numpy())
    preds = np.concatenate(preds)
    true_labels = np.concatenate(true_labels)
    acc = (preds == true_labels).mean()
    f1 = f1_score(true_labels, preds, average='weighted')
    cm = confusion_matrix(true_labels, preds)
    avg_loss = total_loss / len(X)
    return preds, acc, f1, cm, avg_loss
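A quick sanity check of the shapes preprocess_batch produces (a hypothetical snippet, not part of the original notebook):

import torch

dummy = torch.randn(8, 1, 28, 28)     # a fake batch of 8 grayscale digits
print(preprocess_batch(dummy).shape)  # torch.Size([8, 3, 224, 224])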
Pretrained CNNs without fine-tuning¶
mobilenet = models.mobilenet_v2(weights="IMAGENET1K_V1")
mobilenet.classifier[1] = nn.Linear(1280, 10)
mobilenet = mobilenet.to(device)
efficientnet = models.efficientnet_b0(weights="IMAGENET1K_V1")
efficientnet.classifier[1] = nn.Linear(1280, 10)
efficientnet = efficientnet.to(device)
print("Evaluating MobileNetV2...")
preds_mob, acc_mob, f1_mob, cm_mob, loss_mob = evaluate_model_in_batches(
mobilenet, test_X.reshape(-1,1,28,28), test_y, batch_size, device
)
summary("MobileNetV2", acc_mob, f1_mob, cm_mob)
params_mob = sum(p.numel() for p in mobilenet.parameters())
start = time.time()
with torch.no_grad():
_ = mobilenet(preprocess_batch(test_X[:256].reshape(-1,1,28,28).to(device)))
t_mob = time.time() - start
print(f"Params: {params_mob:,}, Inference time: {t_mob:.4f}s, Test Loss: {loss_mob:.4f}")
print("\nEvaluating EfficientNet-B0.")
preds_eff, acc_eff, f1_eff, cm_eff, loss_eff = evaluate_model_in_batches(
efficientnet, test_X.reshape(-1,1,28,28), test_y, batch_size, device
)
summary("EfficientNet-B0", acc_eff, f1_eff, cm_eff)
params_eff = sum(p.numel() for p in efficientnet.parameters())
start = time.time()
with torch.no_grad():
_ = efficientnet(preprocess_batch(test_X[:256].reshape(-1,1,28,28).to(device)))
t_eff = time.time() - start
print(f"Params: {params_eff:,}, Inference time: {t_eff:.4f}s, Test Loss: {loss_eff:.4f}")
Evaluating MobileNetV2...
MobileNetV2 Accuracy: 0.0942, F1-score: 0.0218
Params: 2,236,682, Inference time: 0.0057s, Test Loss: 2.3812
Evaluating EfficientNet-B0...
EfficientNet-B0 Accuracy: 0.1007, F1-score: 0.0776
Params: 4,020,358, Inference time: 0.0082s, Test Loss: 2.2830
Fine-Tuning the Pretrained CNNs¶
mobilenet = models.mobilenet_v2(weights="IMAGENET1K_V1")
# Freeze the pretrained backbone; only the new head will be trained
for p in mobilenet.parameters():
    p.requires_grad = False
in_features = mobilenet.classifier[-1].in_features
# Replace the classifier with two hidden layers and a 10-class output
mobilenet.classifier = nn.Sequential(
    nn.Linear(in_features, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
mobilenet = mobilenet.to(device)
# Optimize only the parameters that still require gradients (the new head)
opt_m = optim.Adam(filter(lambda p: p.requires_grad, mobilenet.parameters()), lr=1e-3)
loss_fn = nn.CrossEntropyLoss()
X_train = train_X.reshape(-1, 1, 28, 28)
y_train = train_y
train_losses_m = []
print("\nTraining MobileNetV2 (fine-tuned with 2 layers)...")
for epoch in range(10):
    mobilenet.train()
    total_loss = 0
    for i in range(0, len(X_train), batch_size):
        Xb = X_train[i:i+batch_size].to(device)
        yb = y_train[i:i+batch_size].to(device)
        Xb = preprocess_batch(Xb)
        opt_m.zero_grad()
        out = mobilenet(Xb)
        loss = loss_fn(out, yb)
        loss.backward()
        opt_m.step()
        total_loss += loss.item() * len(Xb)
    avg_loss = total_loss / len(X_train)
    train_losses_m.append(avg_loss)
    print(f"Epoch {epoch+1}/10 Loss: {avg_loss:.4f}")
preds_m, acc_m, f1_m, cm_m, test_loss_m = evaluate_model_in_batches(
    mobilenet, test_X.reshape(-1, 1, 28, 28), test_y, batch_size, device
)
summary("MobileNetV2 (fine-tuned)", acc_m, f1_m, cm_m, train_losses=train_losses_m)
params_m = sum(p.numel() for p in mobilenet.parameters())
start = time.time()
with torch.no_grad():
    _ = mobilenet(preprocess_batch(test_X[:256].reshape(-1, 1, 28, 28).to(device)))
t_m = time.time() - start
print(f"Params: {params_m:,}, Inference time: {t_m:.4f}s, Test Loss: {test_loss_m:.4f}")
efficient = models.efficientnet_b0(weights="IMAGENET1K_V1")
# Freeze the pretrained backbone
for p in efficient.parameters():
    p.requires_grad = False
# EfficientNet's classifier is Sequential(Dropout, Linear); handle both layouts
last_layer = efficient.classifier[-1]
if isinstance(last_layer, nn.Linear):
    in_features = last_layer.in_features
else:
    in_features = last_layer[1].in_features
efficient.classifier = nn.Sequential(
    nn.Linear(in_features, 256),
    nn.ReLU(),
    nn.Linear(256, 128),
    nn.ReLU(),
    nn.Linear(128, 10)
)
efficient = efficient.to(device)
opt_e = optim.Adam(filter(lambda p: p.requires_grad, efficient.parameters()), lr=1e-3)
train_losses_e = []
for epoch in range(10):
    efficient.train()
    total_loss = 0
    for i in range(0, len(X_train), batch_size):
        Xb = X_train[i:i+batch_size].to(device)
        yb = y_train[i:i+batch_size].to(device)
        Xb = preprocess_batch(Xb)
        opt_e.zero_grad()
        out = efficient(Xb)
        loss = loss_fn(out, yb)
        loss.backward()
        opt_e.step()
        total_loss += loss.item() * len(Xb)
    avg_loss = total_loss / len(X_train)
    train_losses_e.append(avg_loss)
    print(f"Epoch {epoch+1}/10 Loss: {avg_loss:.4f}")
preds_e, acc_e, f1_e, cm_e, test_loss_e = evaluate_model_in_batches(
    efficient, test_X.reshape(-1, 1, 28, 28), test_y, batch_size, device
)
summary("EfficientNet_B0 (fine-tuned)", acc_e, f1_e, cm_e, train_losses=train_losses_e)
params_e = sum(p.numel() for p in efficient.parameters())
start = time.time()
with torch.no_grad():
    _ = efficient(preprocess_batch(test_X[:256].reshape(-1, 1, 28, 28).to(device)))
t_e = time.time() - start
print(f"Params: {params_e:,}, Inference time: {t_e:.4f}s, Test Loss: {test_loss_e:.4f}")
Training MobileNetV2 (fine-tuned with 2 layers)...
Epoch 1/10 Loss: 0.3012 Epoch 2/10 Loss: 0.1365 Epoch 3/10 Loss: 0.1097 Epoch 4/10 Loss: 0.0922 Epoch 5/10 Loss: 0.0826 Epoch 6/10 Loss: 0.0742 Epoch 7/10 Loss: 0.0709 Epoch 8/10 Loss: 0.0650 Epoch 9/10 Loss: 0.0561 Epoch 10/10 Loss: 0.0474
MobileNetV2 (fine-tuned) Accuracy: 0.9721, F1-score: 0.9721
Params: 2,585,994, Inference time: 0.0059s, Test Loss: 0.1047
Epoch 1/10 Loss: 0.3345 Epoch 2/10 Loss: 0.1747 Epoch 3/10 Loss: 0.1461 Epoch 4/10 Loss: 0.1220 Epoch 5/10 Loss: 0.1081 Epoch 6/10 Loss: 0.0986 Epoch 7/10 Loss: 0.0910 Epoch 8/10 Loss: 0.0814 Epoch 9/10 Loss: 0.0784 Epoch 10/10 Loss: 0.0728
EfficientNet_B0 (fine-tuned) Accuracy: 0.9730, F1-score: 0.9730
Params: 4,369,670, Inference time: 0.0084s, Test Loss: 0.0921
Summary of Question 2¶
| Model | Accuracy | F1 Score | Inference Time (s) | Parameters |
|---|---|---|---|---|
| CNN (from scratch) | 0.9851 | 0.9850 | 0.0006 (512 imgs) | 693,962 |
| MobileNetV2 (fine-tuned) | 0.9721 | 0.9721 | 0.0059 (256 imgs) | 2,585,994 |
| EfficientNet_B0 (fine-tuned) | 0.9730 | 0.9730 | 0.0084 (256 imgs) | 4,369,670 |
| MobileNetV2 (pretrained only) | 0.0942 | 0.0218 | 0.0057 (256 imgs) | 2,236,682 |
| EfficientNet_B0 (pretrained only) | 0.1007 | 0.0776 | 0.0082 (256 imgs) | 4,020,358 |
We observe that the simple CNN achieved the highest accuracy and F1-score of all models, while also having the fewest parameters and the lowest inference time. Using the pretrained models without any fine-tuning gave an accuracy of around 10%, no better than random guessing over 10 classes, which the confusion matrices confirm. This is likely because the pretrained models were trained on ImageNet, which primarily contains RGB images of objects and animals, and those features do not transfer directly to handwritten digit recognition.
We experimented with transfer learning by removing the last layer and replacing it with 2 hidden layers plus an output layer of 10 neurons, then fine-tuned the new head on the training set for 10 epochs. We used MobileNetV2 and EfficientNet_B0. Although these are powerful pretrained models with large parameter counts, they still do not perform as well on MNIST as the simple CNN, because they were trained to recognize features from complex RGB images, and those features do not translate to the simple, limited features in MNIST. A small CNN can learn such features from scratch very easily, reaching higher accuracy with far fewer parameters.
Overall, after fine-tuning, the pretrained models still achieve higher accuracy than the other models (MLP, Random Forest, and Logistic Regression). We could likely reach higher accuracy with more fine-tuning and training for more epochs. We also observe that inference time grows roughly in line with the number of parameters.
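One caveat on the timing comparison: CUDA launches kernels asynchronously, so wall-clock timings taken without synchronization can under-report GPU work. A more careful measurement might look like the following sketch (timed_inference is our name, not a function from the notebook):

import time
import torch

def timed_inference(model, batch):
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # finish any pending GPU work first
    start = time.time()
    with torch.no_grad():
        _ = model(batch)
    if torch.cuda.is_available():
        torch.cuda.synchronize()  # wait for the forward pass to complete
    return time.time() - start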